{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 50000 samples, validate on 10000 samples\n",
      "Epoch 1/5\n",
      "50000/50000 [==============================] - 65s 1ms/sample - loss: 1.8470 - accuracy: 0.3300 - val_loss: 1.6375 - val_accuracy: 0.4177\n",
      "Epoch 2/5\n",
      "50000/50000 [==============================] - 53s 1ms/sample - loss: 1.6002 - accuracy: 0.4265 - val_loss: 1.5135 - val_accuracy: 0.4611\n",
      "Epoch 3/5\n",
      "50000/50000 [==============================] - 50s 1ms/sample - loss: 1.4906 - accuracy: 0.4657 - val_loss: 1.4620 - val_accuracy: 0.4753\n",
      "Epoch 4/5\n",
      "50000/50000 [==============================] - 48s 952us/sample - loss: 1.4306 - accuracy: 0.4908 - val_loss: 1.4101 - val_accuracy: 0.4985\n",
      "Epoch 5/5\n",
      "50000/50000 [==============================] - 47s 947us/sample - loss: 1.3803 - accuracy: 0.5077 - val_loss: 1.3687 - val_accuracy: 0.5145\n"
     ]
    }
   ],
   "source": [
    "'''将二进制格式的MNIST数据集转成.jpg图片格式并保存，图片标签包含在图片名中'''\n",
    "from tqdm import tqdm_notebook\n",
    "import numpy as np\n",
    "import cv2\n",
    "import os\n",
    "\n",
    "def save_mnist_to_jpg(mnist_image_file, mnist_label_file, save_dir):\n",
    "    if 'train' in os.path.basename(mnist_image_file):\n",
    "        num_file = 60000\n",
    "        prefix = 'train'\n",
    "    else:\n",
    "        num_file = 10000\n",
    "        prefix = 'test'\n",
    "    with open(mnist_image_file, 'rb') as f1:\n",
    "        image_file = f1.read()\n",
    "    with open(mnist_label_file, 'rb') as f2:\n",
    "        label_file = f2.read()\n",
    "    image_file = image_file[16:]\n",
    "    label_file = label_file[8:]\n",
    "    for i in tqdm_notebook(range(num_file)):\n",
    "        label = int(label_file[i])\n",
    "        image_list = [int(item) for item in image_file[i*784:i*784+784]]\n",
    "        image_np = np.array(image_list, dtype=np.uint8).reshape(28,28,1)\n",
    "        save_name = os.path.join(save_dir, '{}_{}_{}.jpg'.format(prefix, i, label))\n",
    "        cv2.imwrite(save_name, image_np)\n",
    "        print('{} ==> {}_{}_{}.jpg'.format(i, prefix, i, label))\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    train_image_file = 'D:/datasets/mnist/train-images.idx3-ubyte'\n",
    "    train_label_file = 'D:/datasets/mnist/train-labels.idx1-ubyte'\n",
    "    test_image_file = 'D:/datasets/mnist/t10k-images.idx3-ubyte'\n",
    "    test_label_file = 'D:/datasets/mnist/t10k-labels.idx1-ubyte'\n",
    "\n",
    "    save_train_dir = 'D:/datasets/mnist/train_images/'\n",
    "    save_test_dir ='D:/datasets/mnist/test_images/'\n",
    "\n",
    "    if not os.path.exists(save_train_dir):\n",
    "        os.makedirs(save_train_dir)\n",
    "    if not os.path.exists(save_test_dir):\n",
    "        os.makedirs(save_test_dir)\n",
    "\n",
    "    save_mnist_to_jpg(train_image_file, train_label_file, save_train_dir)\n",
    "    save_mnist_to_jpg(test_image_file, test_label_file, save_test_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
